/**
 * \file dnn/test/naive/dct.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/dct_ref.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

TEST_F(NAIVE, DCT) {
    //! Plain DCT with default params and no channel mask (mask offset and
    //! mask value inputs are left empty). A single-channel 16x16 uint8 image
    //! produces a float32 result of shape (1, 64, 2, 2): 64 coefficients for
    //! each of the 2x2 block positions.
    Checker<DctChannelSelectForward> dct_checker(
            handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param dct_param;

    auto image = TensorValue(
            {1, 1, 16, 16}, dtype::Uint8(),
            {87,  155, 59,  161, 24,  200, 58,  3,   40,  43,  156, 7,
             176, 232, 226, 78,  73,  236, 185, 109, 196, 169, 62,  32,
             167, 180, 96,  157, 101, 53,  150, 47,  26,  238, 218, 210,
             204, 236, 249, 111, 16,  35,  169, 204, 117, 16,  3,   147,
             12,  233, 135, 162, 58,  118, 184, 237, 90,  105, 156, 195,
             196, 104, 138, 19,  82,  62,  126, 140, 220, 171, 206, 232,
             105, 123, 2,   135, 137, 41,  26,  219, 167, 245, 104, 103,
             24,  144, 141, 210, 208, 114, 169, 170, 22,  11,  69,  106,
             236, 150, 57,  184, 75,  241, 28,  175, 178, 186, 190, 124,
             187, 116, 112, 162, 214, 154, 207, 31,  43,  40,  15,  188,
             81,  197, 20,  199, 246, 132, 159, 111, 79,  95,  148, 184,
             171, 173, 203, 146, 150, 33,  178, 9,   141, 49,  237, 222,
             72,  5,   23,  38,  248, 82,  93,  229, 70,  180, 149, 232,
             245, 72,  196, 138, 4,   31,  160, 30,  8,   109, 153, 252,
             204, 126, 15,  182, 145, 130, 179, 234, 21,  240, 144, 105,
             77,  116, 155, 232, 168, 99,  159, 92,  251, 223, 119, 173,
             166, 39,  228, 91,  34,  5,   62,  172, 131, 164, 143, 10,
             161, 165, 221, 214, 178, 110, 185, 254, 152, 149, 46,  144,
             173, 237, 76,  210, 221, 45,  200, 113, 58,  20,  47,  135,
             228, 80,  91,  51,  238, 194, 222, 231, 174, 244, 139, 96,
             71,  25,  25,  62,  172, 181, 71,  27,  86,  0,   121, 38,
             199, 236, 93,  158});

    //! Reference coefficients expected from the naive implementation.
    auto coeffs = TensorValue(
            {1, 64, 2, 2}, dtype::Float32(),
            {1.10687500e+03,  9.59500000e+02,  8.98125000e+02,
             1.21912500e+03,  1.38846378e+01,  3.91629181e+01,
             -1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
             -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
             -4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
             6.31493912e+01,  -9.96249924e+01, 7.72499924e+01,
             7.46250153e+01,  5.81250114e+01,  -9.07061768e+01,
             -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
             8.55864143e+00,  -7.36760712e+01, 6.20557327e+01,
             -2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
             5.21866226e+01,  1.07624054e+02,  -6.16851950e+00,
             -8.56008530e+01, 7.35654449e+01,  -2.56767311e+01,
             -2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
             -6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
             1.04179153e+02,  7.45253448e+01,  3.19730816e+01,
             1.24306192e+01,  -9.93905945e+01, -8.95680237e+01,
             -1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
             4.50356903e+01,  -3.65339231e+00, 5.79474449e+01,
             -2.46253452e+01, 3.29394951e+01,  -1.09065903e+02,
             5.23808861e+01,  -1.00386992e+01, -7.92311325e+01,
             -1.44292374e+01, 5.74285736e+01,  2.28798485e+01,
             6.84826508e+01,  -1.49241837e+02, 9.35751495e+01,
             -4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
             -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
             1.58941059e+01,  -5.66967869e+00, -1.53566467e+02,
             -4.33494377e+01, 8.12108078e+01,  1.21169144e+02,
             2.14673615e+02,  -3.72018318e+01, 2.45811577e+01,
             -1.27189613e+02, 4.98553581e+01,  -5.83694696e+00,
             -4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
             5.16259460e+01,  1.07266571e+02,  -3.41748886e+01,
             -5.44621315e+01, 6.25573196e+01,  -4.24649086e+01,
             4.42625465e+01,  2.71147366e+01,  4.83264275e+01,
             -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
             2.48003254e+01,  -1.74687519e+01, 9.44530487e-01,
             1.35930038e+02,  6.72219162e+01,  4.53297043e+01,
             1.37072708e+02,  -7.73253784e+01, 6.12967606e+01,
             9.78184891e+01,  3.63894577e+01,  -1.64039135e+01,
             -6.67858887e+01, 5.27859840e+01,  -4.99117432e+01,
             8.77927475e+01,  -5.86666260e+01, 3.86430244e+01,
             2.17759323e+01,  8.34562683e+01,  3.06256886e+01,
             1.61030369e+01,  8.11268158e+01,  1.36932516e+01,
             -1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
             -4.90609503e+00, 7.96453857e+01,  -1.02625000e+02,
             1.40000076e+01,  3.18749981e+01,  -1.08375000e+02,
             -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
             -1.44041641e+02, 4.86086197e+01,  -7.13610382e+01,
             3.06417294e+01,  7.20477829e+01,  -6.95384140e+01,
             1.25441925e+02,  -1.54897385e+01, 3.78566666e+01,
             4.23749886e+01,  -3.37500000e+01, -9.96250000e+01,
             -6.73750076e+01, 3.34241295e+01,  -6.24825974e+01,
             1.76387348e+01,  -6.45708389e+01, 1.70728874e+01,
             -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
             4.17566071e+01,  7.08248520e+00,  -2.59306641e+01,
             1.37766739e+02,  -2.16669798e+00, 6.03565750e+01,
             6.84421844e+01,  6.19825096e+01,  -1.44220114e+01,
             -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
             2.52050266e+01,  -8.35850677e+01, -4.70746574e+01,
             1.73889160e+01,  1.18955564e+01,  6.16792488e+00,
             -3.29667168e+01, 4.55779572e+01,  -4.17868996e+00,
             -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
             5.25992851e+01,  1.23662634e+01,  5.26129305e-01,
             4.69518929e+01,  -1.52657738e+01, 9.96897888e+01,
             -9.51726151e+01, 9.99432602e+01,  -1.75949844e+02,
             1.00472336e+02,  -5.89417953e+01, -1.72231483e+01,
             1.89282093e+01,  -8.17851868e+01, 7.22908936e+01,
             -9.06294174e+01, 2.46093607e+00,  -4.03946457e+01,
             2.17710762e+01,  -5.62999649e+01, 4.77665749e+01,
             -4.04248848e+01, 4.78787374e+00,  1.05557320e+02,
             -4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
             1.71907139e+01,  -8.01314316e+01, 1.69647141e+01,
             -8.24824219e+01, 8.29206543e+01,  3.72900200e+01,
             3.77470016e+01,  6.70151443e+01,  1.79784470e+01,
             -4.01441078e+01, 6.29196739e+01,  7.60664597e+01,
             -5.59005699e+01, 8.81600475e+00,  -6.89491081e+00,
             -8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
             -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
             -8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
             -3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
             3.94072838e+01,  1.24754738e+02,  9.51344986e+01,
             -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
             4.94889374e+01,  7.64868622e+01,  -1.49546280e+01,
             -3.70648766e+01, 5.55572205e+01,  -1.17196434e+02,
             9.20216217e+01,  3.29843826e+01,  3.25113411e+01,
             5.62059135e+01,  6.30202141e+01,  4.99030991e+01,
             2.85804024e+01,  -1.44606361e+01, 7.64952774e+01,
             -2.95697536e+01});

    dct_checker.set_param(dct_param).exect(
            Testcase{image, {}, {}, {}}, Testcase{{}, {}, {}, coeffs});
}

TEST_F(NAIVE, DCT_INT8) {
    //! DCT with NCHW4 output format and no channel mask. A single-channel
    //! 16x16 uint8 image is expected to produce a QuantizedS8(10.f) tensor of
    //! shape (1, 16, 2, 2, 4), i.e. 64 channels packed as 16 groups of 4.
    Checker<DctChannelSelectForward> dct_checker(
            handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param dct_param;
    dct_param.format = DctChannelSelectForward::Param::Format::NCHW4;

    auto image = TensorValue(
            {1, 1, 16, 16}, dtype::Uint8(),
            {113, 223, 229, 159, 249, 252, 89,  84,  45,  16,  41,  72,
             184, 236, 70,  184, 86,  172, 218, 211, 47,  177, 18,  85,
             174, 226, 37,  109, 38,  135, 228, 195, 133, 238, 47,  246,
             244, 118, 175, 143, 34,  10,  28,  4,   82,  103, 89,  55,
             235, 78,  151, 178, 249, 62,  183, 84,  105, 0,   121, 98,
             249, 90,  161, 114, 121, 241, 21,  199, 196, 119, 231, 209,
             250, 180, 192, 213, 116, 105, 114, 169, 1,   142, 3,   30,
             140, 245, 201, 109, 19,  26,  224, 68,  123, 228, 64,  150,
             184, 212, 136, 172, 241, 152, 222, 233, 15,  72,  130, 144,
             107, 130, 242, 79,  195, 46,  226, 57,  183, 36,  88,  161,
             121, 170, 2,   215, 109, 212, 35,  18,  76,  197, 117, 81,
             208, 8,   237, 75,  15,  20,  16,  192, 61,  113, 96,  126,
             211, 57,  49,  62,  185, 211, 155, 87,  233, 163, 164, 84,
             61,  28,  1,   11,  190, 253, 145, 30,  38,  98,  153, 56,
             231, 152, 12,  204, 96,  8,   47,  87,  25,  237, 21,  150,
             173, 19,  41,  175, 164, 231, 39,  145, 39,  187, 210, 123,
             165, 98,  87,  242, 38,  136, 182, 145, 41,  47,  147, 171,
             172, 35,  170, 148, 26,  89,  107, 151, 130, 232, 65,  217,
             27,  206, 68,  219, 60,  106, 3,   209, 175, 189, 191, 32,
             119, 141, 56,  48,  105, 58,  94,  163, 185, 60,  83,  249,
             112, 245, 137, 60,  178, 51,  177, 106, 199, 209, 4,   247,
             3,   127, 88,  46});

    //! Reference quantized coefficients (scale 10.f).
    auto quantized_coeffs = TensorValue(
            {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
            {122, -1,  -8,  4,   92,  -13, -5,  7,   99,  4,   5,   3,
             89,  7,   2,   -6,  3,   -8,  -10, 2,   -1,  0,   4,   -3,
             -5,  -8,  -11, 1,   14,  4,   -10, -18, 3,   12,  -14, -2,
             -4,  -9,  12,  4,   -2,  -2,  2,   6,   -9,  6,   1,   5,
             -5,  -1,  2,   -12, 4,   -5,  -0,  4,   1,   5,   -8,  5,
             -3,  4,   2,   6,   -0,  9,   -4,  -7,  -4,  -5,  -2,  8,
             2,   4,   0,   7,   -8,  4,   -2,  3,   -6,  -5,  19,  5,
             -4,  -4,  -5,  -16, -8,  -3,  -5,  19,  4,   3,   4,   -6,
             1,   -12, -1,  7,   11,  -5,  -1,  -8,  2,   -12, -9,  -2,
             -4,  -20, -11, -15, -15, -9,  -2,  -9,  -2,  -3,  13,  2,
             5,   6,   7,   -4,  1,   -7,  6,   4,   2,   6,   0,   -0,
             8,   8,   -6,  5,   1,   -2,  -2,  -12, 2,   -12, -2,  6,
             7,   3,   4,   14,  14,  -3,  1,   -3,  6,   0,   -20, 2,
             -10, 10,  -5,  -5,  13,  0,   -3,  7,   -12, -17, -13, 1,
             -6,  10,  -1,  -9,  4,   -16, 3,   2,   5,   1,   -4,  9,
             -0,  1,   3,   15,  -4,  -13, -6,  4,   3,   -2,  -1,  -4,
             -7,  -7,  -2,  8,   -16, -4,  -10, 5,   1,   -3,  2,   -9,
             -4,  1,   -1,  -1,  -4,  -6,  -4,  1,   0,   -9,  15,  -1,
             -7,  -3,  -5,  -0,  3,   -0,  -6,  -17, 16,  -3,  3,   -2,
             -3,  5,   3,   -2,  3,   13,  8,   1,   -3,  -8,  -7,  -4,
             6,   -6,  -15, -7,  0,   4,   -3,  -3,  -10, 14,  1,   3,
             14,  4,   -1,  14});

    dct_checker.set_param(dct_param).exect(
            Testcase{image, {}, {}, {}},
            Testcase{{}, {}, {}, quantized_coeffs});
}

TEST_F(NAIVE, DCT_INT8_MASK) {
    Checker<DctChannelSelectForward> checker(
            handle(),
            /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;
    //! 3-channel 8x16 uint8 input, shared by both mask configurations below
    //! (previously the second case re-typed an identical copy of it).
    auto src_tensor = TensorValue(
            {1, 3, 8, 16}, dtype::Uint8(),
            {195, 165, 82,  30,  154, 60,  175, 195, 179, 165, 132, 37,  250, 107, 36,
             80,  5,   54,  247, 218, 191, 211, 239, 76,  140, 33,  253, 85,  132, 101,
             105, 177, 46,  183, 102, 99,  19,  175, 108, 252, 42,  238, 48,  251, 108,
             90,  176, 2,   35,  46,  161, 252, 38,  225, 195, 174, 58,  165, 198, 249,
             162, 118, 198, 41,  154, 10,  87,  24,  201, 12,  188, 1,   93,  179, 246,
             134, 18,  178, 173, 36,  122, 89,  115, 46,  43,  205, 232, 55,  149, 30,
             206, 97,  186, 125, 35,  209, 51,  48,  222, 222, 130, 173, 63,  0,   223,
             19,  5,   162, 154, 143, 134, 63,  123, 102, 102, 212, 145, 80,  87,  212,
             42,  26,  219, 225, 120, 94,  213, 238,

             25,  172, 141, 45,  182, 203, 50,  94,  44,  88,  74,  76,  151, 105, 138,
             87,  125, 55,  60,  211, 15,  158, 198, 37,  54,  203, 239, 79,  56,  6,
             53,  201, 97,  233, 178, 74,  193, 46,  249, 65,  5,   208, 130, 67,  191,
             168, 152, 129, 253, 195, 231, 3,   109, 229, 254, 193, 229, 202, 108, 22,
             89,  251, 13,  53,  47,  192, 12,  81,  19,  53,  93,  104, 41,  217, 215,
             184, 136, 249, 14,  244, 4,   220, 33,  53,  142, 219, 43,  28,  68,  198,
             202, 88,  235, 7,   233, 47,  84,  127, 28,  17,  189, 135, 183, 192, 239,
             116, 31,  118, 186, 49,  251, 233, 220, 27,  97,  30,  43,  193, 217, 48,
             24,  225, 15,  3,   26,  71,  82,  104,

             175, 125, 79,  195, 50,  236, 114, 179, 180, 177, 230, 173, 43,  195, 123,
             111, 106, 5,   91,  254, 34,  76,  52,  82,  193, 179, 185, 71,  57,  215,
             18,  5,   151, 13,  59,  206, 154, 95,  149, 40,  229, 16,  116, 144, 249,
             67,  97,  223, 208, 144, 92,  174, 246, 77,  196, 211, 20,  123, 239, 250,
             235, 65,  184, 54,  239, 168, 135, 17,  79,  117, 171, 173, 109, 39,  57,
             13,  129, 79,  236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226,
             44,  113, 164, 217, 46,  249, 182, 22,  98,  228, 49,  78,  101, 236, 181,
             5,   245, 72,  62,  182, 151, 210, 254, 190, 35,  73,  190, 247, 50,  81,
             49,  217, 86,  229, 139, 203, 57,  194});

    //! Case 1: mask_offset has one entry per input channel plus one, and its
    //! last entry equals the length of mask_val; presumably input channel i
    //! selects mask_val[offset[i], offset[i+1]) (16 + 8 + 8 = 32 values),
    //! matching the 8 NCHW4 groups of 4 channels in the expected output.
    checker.set_param(param).exect(
            Testcase{
                    src_tensor,
                    TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                    TensorValue(
                            {32}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
                            {100, -12, 7,  7,  104, 2,  -2, -2, -7, -7,  -3, 8,  12,
                             -12, -5,  -1, 5,  -7,  -1, 7,  -7, -3, 6,   7,  -0, -2,
                             -7,  11,  6,  3,  -1,  7,  94, -5, 6,  -5,  98, 0,  -3,
                             -16, 5,   7,  13, -8,  1,  5,  -5, -8, 108, -3, -8, -7,
                             110, 1,   -2, 5,  -0,  7,  8,  -9, 14, -0,  1,  -4})});

    //! Case 2: same input, but channel 0 drops its last four coefficients
    //! (18, 11, 4, 5). The expected output is case 1's with its fourth
    //! NCHW4 group removed, leaving 28 channels in 7 groups.
    checker.set_param(param).exect(
            Testcase{
                    src_tensor,
                    TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
                    TensorValue(
                            {28}, dtype::Int32(),
                            {0, 1,  8, 16, 9, 2,  3, 10, 17, 24, 32, 25, 0, 1,
                             8, 16, 9, 2,  3, 10, 0, 1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
                            {100, -12, 7,  7,  104, 2,  -2, -2,  -7,  -7, -3, 8,
                             12,  -12, -5, -1, 5,   -7, -1, 7,   -7,  -3, 6,  7,

                             94,  -5,  6,  -5, 98,  0,  -3, -16, 5,   7,  13, -8,
                             1,   5,   -5, -8, 108, -3, -8, -7,  110, 1,  -2, 5,
                             -0,  7,   8,  -9, 14,  -0, 1,  -4})});
}

TEST_F(NAIVE, DCT_4x4) {
    Checker<DctChannelSelectForward> checker(
            handle(),
            /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.dct_block_size = 4;
    //! Single-channel 8x8 uint8 input, shared by both cases below (previously
    //! duplicated verbatim). With a block size of 4 the expected output holds
    //! 16 coefficients for each of the 2x2 block positions.
    auto src_tensor = TensorValue(
            {1, 1, 8, 8}, dtype::Uint8(),
            {186, 120, 112, 220, 69,  80,  201, 127, 246, 254, 175,
             50,  240, 251, 76,  37,  34,  166, 250, 195, 231, 139,
             128, 233, 75,  80,  3,   2,   19,  140, 193, 203, 115,
             107, 250, 209, 14,  243, 199, 60,  234, 107, 174, 156,
             81,  87,  13,  116, 96,  140, 197, 253, 113, 223, 229,
             159, 249, 252, 89,  84,  45,  16,  41,  72});

    //! Case 1: no channel mask, so all 16 coefficient channels are produced.
    checker.set_param(param).exect(
            Testcase{src_tensor, {}, {}, {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 16, 2, 2}, dtype::Float32(),
                            {5.42000000e+02,  5.91750000e+02,  6.78000000e+02,
                             4.27750000e+02,  3.49953423e+01,  -1.17686939e+01,
                             -1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
                             -1.22500000e+01, 2.00000000e+01,  -9.77500000e+01,
                             -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
                             -4.92537880e+01, 1.66958221e+02,  -4.26609573e+01,
                             2.56999969e-01,  5.39384537e+01,  1.71819706e+01,
                             9.00009003e+01,  -1.23818558e+02, 1.18912420e+01,
                             6.61014938e+01,  -2.49261990e+01, 4.95798302e+00,
                             -1.02324417e+02, 7.85859919e+00,  3.73140755e+01,
                             1.03783745e+02,  -4.61430321e+01, -1.43000000e+02,
                             -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
                             1.34834738e+01,  -1.93409515e+02, 6.84791718e+01,
                             -4.01652241e+00, 1.22000000e+02,  -8.57500000e+01,
                             -4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
                             5.76532059e+01,  -2.67414131e+01, 1.70877876e+01,
                             3.85416756e+01,  3.09300461e+01,  5.84670639e+00,
                             1.85747864e+02,  -2.05141403e+02, -9.91859360e+01,
                             -1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
                             8.41980438e+01,  -3.50666313e+01, -1.48387482e+02,
                             1.08180256e+01,  5.49991112e+01,  -1.06814528e+01,
                             1.86087704e+01})});

    //! Case 2: mask_offset {0, 6} selects six coefficients for the only input
    //! channel. The expected output equals channels 0, 1, 8, 4, 2, 3 of the
    //! case-1 result, i.e. mask_val picks coefficient channels in order.
    checker.set_param(param).exect(
            Testcase{
                    src_tensor,
                    TensorValue({2}, dtype::Int32(), {0, 6}),
                    TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 6, 2, 2}, dtype::Float32(),
                            {5.4200000e+02,  5.9175000e+02,  6.7800000e+02,
                             4.2775000e+02,  3.4995342e+01,  -1.1768694e+01,
                             -1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
                             -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
                             1.6695822e+02,  -4.2660957e+01, 2.5699997e-01,
                             5.3938454e+01,  -3.8000000e+01, -1.2250000e+01,
                             2.0000000e+01,  -9.7750000e+01, -1.6119131e+01,
                             -9.4669533e+00, 3.2888241e+01,  -4.9253788e+01})});
}

TEST_F(NAIVE, DCT_WITH_MASK) {
    Checker<DctChannelSelectForward> checker(
            handle(),
            /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {1, 3, 8, 16}, dtype::Uint8(),
                            {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204,

                             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204,

                             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204}),
                    TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                    TensorValue(
                            {32}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 32, 1, 2}, dtype::Float32(),
                            {890.12494,  941.25,     -7.0498576,  99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675,  -157.96654, 42.1377,     45.06531,
                             -149.77373, 24.487143,  -8.054966,   -13.990831,
                             -6.9395194, -3.9211385, 64.79172,    -12.363858,
                             -47.875,    59.,        56.271786,   -62.725567,
                             120.522675, 16.559765,  85.74334,    112.904495,
                             99.375,     29.499973,  2.0220923,   -19.681704,
                             890.12494,  941.25,     -7.0498576,  99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675,  -157.96654, 42.1377,     45.06531,
                             -149.77373, 24.487143,  -8.054966,   -13.990831,
                             890.12494,  941.25,     -7.0498576,  99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675,  -157.96654, 42.1377,     45.06531,
                             -149.77373, 24.487143,  -8.054966,   -13.990831})});
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {1, 3, 8, 16}, dtype::Uint8(),
                            {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204,

                             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204,

                             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
                             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
                             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
                             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
                             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
                             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
                             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
                             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
                             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                             238, 181, 232, 191, 161, 57,  23,  204}),
                    TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
                    TensorValue({24}, dtype::Int32(), {17, 24, 32, 25, 18, 11, 4, 5,
                                                       0,  1,  8,  16, 9,  2,  3, 10,
                                                       0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 24, 1, 2}, dtype::Float32(),
                            {-6.9395194, -3.9211385, 64.79172,    -12.363858,
                             -47.875,    59.,        56.271786,   -62.725567,
                             120.522675, 16.559765,  85.74334,    112.904495,
                             99.375,     29.499973,  2.0220923,   -19.681704,
                             890.12494,  941.25,     -7.0498576,  99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675,  -157.96654, 42.1377,     45.06531,
                             -149.77373, 24.487143,  -8.054966,   -13.990831,
                             890.12494,  941.25,     -7.0498576,  99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675,  -157.96654, 42.1377,     45.06531,
                             -149.77373, 24.487143,  -8.054966,   -13.990831})});
}

TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
    // Runs DctChannelSelectForward with the FIX_32_MASK fast implementation
    // selected and compares its output against hard-coded reference values.
    using Param = DctChannelSelectForward::Param;
    Param fix32_param;
    fix32_param.fastImpl = Param::FastImpl::FIX_32_MASK;

    Checker<DctChannelSelectForward> checker(
            handle(), /* check_dispatch */ false);

    // 1x3x8x16 uint8 input. The same 128 bytes are repeated for all three
    // channels.
    auto src = TensorValue(
            {1, 3, 8, 16}, dtype::Uint8(),
            {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
             238, 181, 232, 191, 161, 57,  23,  204,
             // channel 1 (identical to channel 0)
             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
             238, 181, 232, 191, 161, 57,  23,  204,
             // channel 2 (identical to channel 0)
             109, 39,  30,  115, 71,  15,  206, 139, 221, 5,   18,  16,
             93,  185, 99,  102, 205, 172, 191, 29,  185, 6,   47,  84,
             0,   47,  105, 203, 251, 73,  196, 83,  3,   211, 32,  181,
             49,  111, 114, 83,  148, 232, 77,  17,  35,  2,   154, 100,
             41,  135, 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
             78,  65,  184, 69,  91,  82,  2,   172, 194, 240, 49,  145,
             87,  210, 97,  190, 179, 93,  125, 105, 181, 207, 148, 178,
             133, 53,  25,  198, 238, 151, 14,  120, 213, 195, 145, 20,
             122, 107, 217, 185, 65,  5,   115, 110, 82,  206, 163, 86,
             2,   2,   44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
             238, 181, 232, 191, 161, 57,  23,  204});

    // NOTE(review): presumably mask_offset[c]..mask_offset[c + 1] delimits the
    // slice of mask_val used for input channel c (16 + 8 + 8 = 32 entries,
    // matching the 32 output channels below). Confirm against the operator
    // definition.
    auto mask_offset = TensorValue({4}, dtype::Int32(), {0, 16, 24, 32});

    // Coefficient indices. These look like raster positions inside an 8x8
    // block listed in zig-zag order (assumption; verify).
    auto mask_val = TensorValue(
            {32}, dtype::Int32(),
            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10});

    // Expected 1x32x1x2 float32 output (two values per output channel).
    // Because all input channels are identical, the last 32 values repeat
    // the first 16 values twice.
    auto expected = TensorValue(
            {1, 32, 1, 2}, dtype::Float32(),
            {890.12494,  941.25,     -7.0498576,  99.47632,
             -22.850792, -97.862236, -101.043236, -4.727012,
             28.275675,  -157.96654, 42.1377,     45.06531,
             -149.77373, 24.487143,  -8.054966,   -13.990831,
             -6.9395194, -3.9211385, 64.79172,    -12.363858,
             -47.875,    59.,        56.271786,   -62.725567,
             120.522675, 16.559765,  85.74334,    112.904495,
             99.375,     29.499973,  2.0220923,   -19.681704,
             890.12494,  941.25,     -7.0498576,  99.47632,
             -22.850792, -97.862236, -101.043236, -4.727012,
             28.275675,  -157.96654, 42.1377,     45.06531,
             -149.77373, 24.487143,  -8.054966,   -13.990831,
             890.12494,  941.25,     -7.0498576,  99.47632,
             -22.850792, -97.862236, -101.043236, -4.727012,
             28.275675,  -157.96654, 42.1377,     45.06531,
             -149.77373, 24.487143,  -8.054966,   -13.990831});

    // Inputs are (src, mask_offset, mask_val, dst placeholder). Only the
    // fourth slot of the output testcase is checked.
    checker.set_param(fix32_param);
    checker.exect(
            Testcase{src, mask_offset, mask_val, {}},
            Testcase{{}, {}, {}, expected});
}

TEST_F(NAIVE, DCT_WITH_MASK2) {
    // Randomised sweep over batch, channel and spatial sizes. For each shape,
    // gen_dct_case builds the inputs and reference output for a mask that
    // keeps a random number of output channels in [1, ic * 64].
    Checker<DctChannelSelectForward> checker(
            handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    UniformIntRNG oc_rng(0, 3 * 64);

    const size_t batch_sizes[] = {1, 3};
    const size_t channel_counts[] = {1, 3};
    const size_t heights[] = {8, 16, 32, 512, 1024};
    const size_t widths[] = {8, 16, 32, 64, 128, 256, 512, 1024};

    for (const size_t n : batch_sizes) {
        for (const size_t ic : channel_counts) {
            for (const size_t ih : heights) {
                for (const size_t iw : widths) {
                    // Upper bound on selectable channels for this input.
                    const int oc_limit = static_cast<int>(ic) * 64;
                    // Fold the random draw into [1, oc_limit].
                    const int selected_oc =
                            static_cast<int>(oc_rng.gen_single_val()) % oc_limit +
                            1;
                    auto dct_case =
                            gen_dct_case(n, ic, ih, iw, selected_oc, param);
                    checker.set_param(param);
                    checker.exect(dct_case->testcase_in, dct_case->testcase_out);
                }
            }
        }
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen